import json
import numpy as np
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, set_seed
from datasets import load_dataset
from tqdm import tqdm
import argparse
import os
import torch
import gc
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import seaborn as sns
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

# torch.set_float32_matmul_precision('high')
# ---- Command-line configuration and artefact paths -------------------------
parser = argparse.ArgumentParser(description='model name as input(hugging face id)')
# Required: every path and model load below is derived from this id, so a
# missing value should produce a usage error rather than a crash on None.
parser.add_argument('--model_name', type=str, required=True, help='hugging face model name')
parser.add_argument('--device_id', type=int, help='GPU ID', default=0)
parser.add_argument('--dtype', type=int, help="dtype for model loading", default=1)

args = parser.parse_args()
model_name = args.model_name
# NOTE(review): `d_type` is parsed but not read anywhere else in this file;
# the language model further down is always loaded with torch.float16.
d_type = args.dtype
gpu_id = args.device_id

# Must be set before the first CUDA call so that only the requested GPU is
# visible to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)

device = 'cuda'
# Last path component of the hub id ("org/name" -> "name"); also works for
# ids that carry no organisation prefix.
model_data_path = model_name.split('/')[-1]

direc_name = f''

# NOTE(review): `new_direc_name` and `encoding_type` are never assigned in this
# file, so the statements below raise NameError as written, and `direc_name`
# is an empty string (paths resolve to the filesystem root). The intended
# values cannot be recovered from this file -- define them before running.
os.makedirs(new_direc_name, exist_ok=True)

# Output file for the generated completions.
generations_prot_save_path = new_direc_name+f'/encoding{encoding_type}_prot_p10.jsonl'
generations_mbpp_save_path = direc_name+f'/base.jsonl'

# Index arrays: rows to exclude from the instruction dataset, and
# precomputed few-shot indices (the latter is not loaded in this file).
del_ids_path = direc_name+f'/encoding{encoding_type}_inf_ind.npy'
few_shot_ind_path = direc_name+f'/encoding{encoding_type}_prot_close_ind.npy'

# Feature matrix, labels, projection-model weights and prototype vectors.
train_feat_path = direc_name+f'/encoding{encoding_type}_arr.npy'
train_action_path = direc_name+f'/encoding{encoding_type}_labels.npy'
model_save_path = direc_name+f'/encoding{encoding_type}_model.pth'
prots_path = direc_name+f'/encoding{encoding_type}_prot.npy'

feat = np.load(train_feat_path)
lab = np.load(train_action_path)

# Lists of per-sample rows / labels, consumed by CustomDataset below.
feat_img = list(feat)
labels = list(lab)
##################new center prots###############
class DuelCNNWrapper(nn.Module):
    """Single learned projection applied to feature vectors.

    The network is ``Linear(vec_len, vec_len)`` followed by
    ``InstanceNorm1d(vec_len)`` and ``ReLU``, stored as ``additional_layer``.

    Args:
        vec_len: size of the last input dimension; the output has the same size.
    """

    def __init__(self,vec_len):
        super().__init__()

        projection = nn.Linear(vec_len, vec_len)
        normalise = nn.InstanceNorm1d(vec_len)
        activation = nn.ReLU()

        # The attribute name and the layer order determine the state_dict
        # keys, so they must stay as they are for saved weights to load.
        self.additional_layer = nn.Sequential(projection, normalise, activation)

    def forward(self, x):
        """Return ``x`` passed through the projection stack."""
        return self.additional_layer(x)
        
# Rebuild the projection network with the feature width of the loaded
# feature matrix and restore its trained weights.
vec_len = feat.shape[1]

model = DuelCNNWrapper(vec_len).to('cuda')
# map_location remaps the stored tensors onto the currently visible GPU, so
# the checkpoint loads even if it was saved from a different CUDA ordinal.
model.load_state_dict(torch.load(model_save_path, map_location='cuda'))
# The network is only used for a forward pass from here on.
model.eval()

class CustomDataset(Dataset):
    """Map-style dataset pairing ``img[idx]`` with ``labels[idx]``.

    Args:
        img: indexable container of samples; its length is the dataset length.
        labels: indexable container of labels, addressed with the same index.
        transform: optional callable applied to the sample (not the label).
    """

    def __init__(self, img, labels, transform=None):
        self.img, self.labels, self.transform = img, labels, transform

    def __len__(self):
        # Only the sample container defines the length.
        return len(self.img)

    def __getitem__(self, idx):
        """Return ``(sample, label)`` for position ``idx``."""
        sample, target = self.img[idx], self.labels[idx]
        # Any truthy ``transform`` is applied; a falsy one is skipped.
        return (self.transform(sample) if self.transform else sample), target

        

# ---- Project every feature vector and match it against the prototypes ------
train_dataset = CustomDataset(feat_img, labels)

# shuffle=False keeps row i of the projected matrix aligned with sample i,
# which matters because the positions found below are used as indices later.
# drop_last=False: this is a forward-only pass, so the trailing partial batch
# must be kept, otherwise the last samples could never be selected.
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False, num_workers=4, drop_last=False)
new_feat = []
new_labels = []

# No gradients are needed for a forward-only pass.
with torch.no_grad():
    for img, label in tqdm(train_loader):
        projected = model(img.to('cuda')).cpu().numpy()
        new_feat.extend(projected)   # one entry per sample row
        new_labels.extend(label)     # one entry per sample label

new_arr = np.array(new_feat)
new_labels = np.array(new_labels)

nn_human = np.load(prots_path)

# For each prototype vector, the position of the projected sample with the
# smallest L2 distance to it.
inds = [np.argmin(np.linalg.norm(new_arr - vec, axis=1)) for vec in nn_human]
####################### new center ports ####################

# ---- Language model used for generation -------------------------------------
# From here on `model` refers to the causal LM, not the projection network.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    output_hidden_states=True,
    torch_dtype=torch.float16,
    device_map='auto'
)
# Reuse the end-of-sequence token for padding.
tokenizer.pad_token = tokenizer.eos_token
# model.eval()
generator = pipeline('text-generation', model=model, tokenizer=tokenizer,device_map='auto')

# ---- Datasets ------------------------------------------------------------------
ds = load_dataset("ise-uiuc/Magicoder-OSS-Instruct-75K")['train']

falcon_prot_file_path = 'mbpp_results/samples_mbpp_p10_falcon_new-sanitized.eval_results.json'
mbpp = load_dataset("Muennighoff/mbpp", "full")
with open(falcon_prot_file_path,'r') as file:
    falcon_prot = json.load(file)

# Keys of the 'eval' mapping end in "/<number>"; that number minus one is used
# as the row position in the MBPP test split.
# NOTE(review): assumes task ids are 1-based row numbers in that split -- confirm.
mbpp_ref = [int(task_id.split('/')[-1]) - 1 for task_id in falcon_prot['eval']]
mbpp_subset = mbpp['test'].select(mbpp_ref)


# Drop the rows listed in del_ids_path. A set gives O(1) membership tests
# instead of scanning the whole array once per dataset row.
del_arr = np.load(del_ids_path)
excluded_rows = set(del_arr.ravel().tolist())
new_ds = ds.select([i for i in range(len(ds)) if i not in excluded_rows])

# few_shot_ind = np.load(few_shot_ind_path)
# Several prototypes may map to the same sample; keep each index once (sorted).
inds = np.unique(inds)

print(f"unique few shots at indexes {inds}")

# few_shot_examples = [
#     # (
#     #     ds['train']['problem'][48176],
#     #     ds['train']['solution'][48176]
#     # )
#     (
#       mbpp['test']['text'][2],
# mbpp['test']['code'][2] 
#     ),
#     (
#               mbpp['test']['text'][3],
# mbpp['test']['code'][3] 
#     ),
#     (
#             mbpp['test']['text'][4],
# mbpp['test']['code'][4]   
#     )

# ]

# (problem, solution) pairs for the selected rows. Each row is fetched once by
# position instead of materialising the whole 'problem' and 'solution' columns
# again for every index.
few_shot_examples = [
    (row['problem'], row['solution'])
    for row in (new_ds[int(ind)] for ind in inds)
]


def build_icl_prompt(few_shots, test_problem):
    """Assemble an in-context-learning prompt.

    Args:
        few_shots: iterable of ``(problem, solution)`` pairs, rendered in order
            as solved examples terminated by ``[DONE]``.
        test_problem: task text for the final, unsolved example.

    Returns:
        The solved examples followed by the test task, ending right after
        ``[BEGIN]`` so the model continues with the solution.
    """
    solved_examples = [
        f"You are an expert Python programmer, and here is your task: {prob}\n[BEGIN]\n{sol}\n[DONE]\n\n"
        for prob, sol in few_shots
    ]
    open_example = f"You are an expert Python programmer, and here is your task: {test_problem}\n[BEGIN]\n"
    return "".join(solved_examples + [open_example])

# ---- Generate completions for the selected MBPP tasks -----------------------
# Use MBPP test split
# test_problems = mbpp['test']
test_problems = mbpp_subset
completions = {}  # task id (str) -> list of generated solutions


for item in tqdm(test_problems, desc="Generating completions"):
    task_id = str(item['task_id'])

    # Few-shot examples followed by this task's description.
    icl_prompt = build_icl_prompt(few_shot_examples, item['text'])

    # Sample 10 continuations per task; only the newly generated text is
    # returned (return_full_text=False), not the prompt.
    output = generator(
        icl_prompt,
        max_new_tokens=512,
        return_full_text=False,
        do_sample=True,
        # num_beams = 10
        # temperature=0.6,
        # top_p=0.9,
        num_return_sequences=10,
    )

    completions[task_id] = [o['generated_text'] for o in output]


# One JSON object per line; "solution" holds the full list of samples for the task.
with open(generations_prot_save_path, "w") as f:
    for task_id, solutions in completions.items():
        f.write(json.dumps({"task_id": f"Mbpp/{task_id}", "solution": solutions}) + "\n")

# Flattening to one line per sample, if needed downstream:
# with open("samples_mbpp_p10_gemma3.jsonl") as fin, open("samples_mbpp_p10_gemma3_new_org.jsonl", "w") as fout:
#     [fout.write(json.dumps({"task_id": d["task_id"], "solution": s}) + "\n") for d in map(json.loads, fin) for s in d["solution"]]

print(f"Saved all completions to {generations_prot_save_path}")

# The pipeline object holds its own reference to the model, so both names
# must be dropped before the weights can actually be garbage-collected.
del generator, model


def flush():
    """Run garbage collection and release cached CUDA memory.

    Always collects Python garbage. When CUDA is available it additionally
    empties the CUDA caching allocator and resets the peak-memory statistics;
    without CUDA those calls are skipped so the function does not raise.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


flush()
